package v1

import (
	"context"
	"errors"
	"gitee.com/zackeus/go-boot/common/constants/net/headers"
	"gitee.com/zackeus/go-boot/scs/v1/codec"
	"gitee.com/zackeus/go-boot/scs/v1/codec/gob"
	"gitee.com/zackeus/go-boot/scs/v1/store"
	"gitee.com/zackeus/go-zero/core/logx"
	"gitee.com/zackeus/goutil"
	"gitee.com/zackeus/goutil/arrutil"
	"net/http"
	"time"
)

type (
	// SessionOption session 可选项
	SessionOption func(s *SessionManager)

	// SessionManager session 管理器
	SessionManager struct {
		// contextKey 是用于从 context.Context 设置和检索会话数据的键 它是自动生成的以确保唯一性
		contextKey contextKey

		// idleTimeout 非活跃最大空闲时长 设置后每次请求都会更新 cookie, 默认不设置
		idleTimeout time.Duration

		// lifetime session 有效时长 为绝对日期 并且不会更新
		// 一般 lifetime 和 idleTimeout 搭配使用, 如 idleTimeout 设为 30min, lifetime 设为 24h, 表示session在24h后过期, 30min中内操作可刷新
		lifetime time.Duration

		// cookie cookie
		cookie SessionCookie

		// store session 存储
		store store.Store

		// codec session 编码器解码器 默认使用 gob 进行编码解码
		codec codec.Codec

		// 白名单
		whiteList []string

		// 自定义 context 处理
		contextHandle ContextHandle
		// unauthorizedCallback 未鉴权回调函数 默认行为是 返回 401 header
		unauthorizedCallback UnauthorizedCallback
		// errorCallback 异常回调函数 默认行为是将 HTTP 500 'Internal Server Error' 消息发送到客户端，并使用 Go 的标准记录器记录错误
		errorCallback ErrorCallback
	}

	// SessionCookie cookie 配置
	SessionCookie struct {
		// name cookie 名称
		name string

		// domain cookie domain
		domain string

		// httpOnly cookie HttpOnly, 默认为 true
		httpOnly bool

		// path cookie path
		path string

		// persist 设置会话 cookie 是否应该持久(即是否应该在用户关闭浏览器后保留) 默认值为 true
		// 这意味着会话 cookie 不会在用户关闭浏览器时被销毁，并且适当的“Expires”和“MaxAge”值将添加到会话 cookie
		// 如果您只想保留某些会话(而不是所有会话) 则将其设置为 false 并为您要保留的特定会话调用 RememberMe() 方法。
		persist bool

		// sameSite 控制会话 cookie 上 'SameSite' 属性的值。默认为“SameSite=Lax”
		// 如果不想在会话 cookie 中使用 SameSite 属性或值，则应将其设置为 0
		// http.SameSiteLaxMode: Cookie 只能在同一站点的请求中被发送，但是可以在导航到站点的跨站点GET请求中发送(例如从其他网站链接到该站点)
		// http.SameSiteStrictMode: 禁止跨站点发送cookie
		sameSite http.SameSite

		// secure 严格模式 只能在 https 中传输
		// See https://github.com/OWASP/CheatSheetSeries/blob/master/cheatsheets/Session_Management_Cheat_Sheet.md#transport-layer-security.
		secure bool
	}
)

// New 构建 SessionManager
func New(name string, lifetime time.Duration, store store.Store, options ...SessionOption) (*SessionManager, error) {
	s := &SessionManager{
		contextKey:  generateContextKey(),
		idleTimeout: 0,
		lifetime:    lifetime,
		store:       store,
		codec:       gob.GobCodec{},
		whiteList:   []string{},
		cookie: SessionCookie{
			name:     name,
			domain:   "",
			httpOnly: true,
			path:     "/",
			persist:  true,
			secure:   false,
			sameSite: http.SameSiteLaxMode,
		},
		errorCallback:        defaultErrorCallback,
		unauthorizedCallback: nil,
	}

	for _, option := range options {
		option(s)
	}
	if err := s.validate(); err != nil {
		return nil, err
	}

	return s, nil
}

// 参数校验
func (s *SessionManager) validate() error {
	if 0 >= s.lifetime {
		return errors.New("the lifetime must be greater than 0")
	}
	if s.idleTimeout > s.lifetime {
		return errors.New("the idleTimeout must be less than lifetime")
	}
	return nil
}

// HandlerFunc 提供自动加载和保存当前请求的会话数据的中间件，并在 cookie 中与客户端通信会话令牌
func (s *SessionManager) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
	return s.Handle(next)
}

// Handle 提供自动加载和保存当前请求的会话数据的中间件，并在 cookie 中与客户端通信会话令牌
func (s *SessionManager) Handle(next http.Handler) http.HandlerFunc {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		/* session 校验白名单 */
		if arrutil.In(r.URL.Path, s.whiteList) {
			next.ServeHTTP(w, r)
			return
		}

		/* 解析 cookie */
		var token string
		cookie, err := s.Cookie(r)
		if goutil.IsEqual(http.ErrNoCookie, err) {
			/* cookie 不存在 返回401 */
			s.onUnauthorizedCallback(w, r, next)
			return
		}
		if err != nil {
			s.errorCallback(w, r, err)
			return
		}
		token = cookie.Value

		/* 从指定 cookie token 加载 session */
		ctx, err := s.Load(r.Context(), token)
		if goutil.IsEqual(SessionNotFound, err) {
			/* session not found 返回401 */
			s.onUnauthorizedCallback(w, r, next)
			return
		}
		if err != nil {
			s.errorCallback(w, r, err)
			return
		}

		if s.contextHandle != nil {
			uid, err := s.Uid(ctx)
			if err != nil {
				s.errorCallback(w, r, err)
				return
			}
			/* 自定义 ctx */
			ctx = s.contextHandle(ctx, uid, token)
		}
		/* request 处理 */
		sr := r.WithContext(ctx)
		sw := &sessionResponseWriter{
			ResponseWriter: w,
			request:        sr,
			sessionManager: s,
		}
		next.ServeHTTP(sw, sr)

		if !sw.written {
			s.commitAndWriteSessionCookie(w, sr)
		}
	})
}

func (s *SessionManager) commitAndWriteSessionCookie(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	status, err := s.Status(ctx)
	if err != nil {
		s.errorCallback(w, r, err)
		return
	}

	switch status {
	case Modified:
		/* session 已修改 提交 */
		token, expiry, err := s.Commit(ctx)
		if err != nil {
			s.errorCallback(w, r, err)
			return
		}
		err = s.WriteSessionCookie(ctx, w, token, expiry)
		if err != nil {
			s.errorCallback(w, r, err)
			return
		}
	case Destroyed:
		/* session 已销毁 */
		err := s.WriteSessionCookie(ctx, w, "", time.Time{})
		if err != nil {
			s.errorCallback(w, r, err)
			return
		}
	}
}

// Cookie 获取 cookie
func (s *SessionManager) Cookie(r *http.Request) (*http.Cookie, error) {
	return r.Cookie(s.cookie.name)
}

// WriteSessionCookie 使用提供的令牌作为 cookie 值，将 cookie 写入 HTTP 响应，并将 expiry 作为 cookie 过期时间
// 仅当会话设置为持续或已调用 RememberMe(true) 时，到期时间才会包含在 cookie 中
// 如果 expiry 是一个空的 time.Time 结构（以便它的 IsZero() 方法返回 true, cookie 将被标记为历史到期时间和负的 max-age（因此浏览器将其删除
func (s *SessionManager) WriteSessionCookie(ctx context.Context, w http.ResponseWriter, token string, expiry time.Time) error {
	cookie := &http.Cookie{
		Name:     s.cookie.name,
		Value:    token,
		Path:     s.cookie.path,
		Domain:   s.cookie.domain,
		Secure:   s.cookie.secure,
		HttpOnly: s.cookie.httpOnly,
		SameSite: s.cookie.sameSite,
	}

	rm, err := s.GetBool(ctx, "__rememberMe")
	if err != nil {
		return err
	}

	if expiry.IsZero() {
		cookie.Expires = time.Unix(1, 0)
		cookie.MaxAge = -1
	} else if s.cookie.persist || rm {
		cookie.Expires = time.Unix(expiry.Unix()+1, 0)        // Round up to the nearest second.
		cookie.MaxAge = int(time.Until(expiry).Seconds() + 1) // Round up to the nearest second.
	}

	w.Header().Add(headers.Vary, "Cookie")
	w.Header().Add(headers.SetCookie, cookie.String())
	w.Header().Add(headers.CacheControl, `no-cache="Set-Cookie"`)
	return nil
}

// Generate session 生成
func (s *SessionManager) Generate(ctx context.Context, uid string) context.Context {
	return s.addSessionDataToContext(ctx, newSessionData(uid, s.lifetime))
}

// Load 从会话存储中检索给定令牌的会话数据，并返回包含会话数据的新上下文
func (s *SessionManager) Load(ctx context.Context, token string) (context.Context, error) {
	if _, ok := ctx.Value(s.contextKey).(*sessionData); ok {
		return ctx, nil
	}
	/* token 为空 */
	if goutil.IsEmpty(token) {
		return nil, SessionNotFound
	}

	b, found, err := s.doStoreGet(ctx, token)
	if err != nil {
		return nil, err
	} else if !found {
		/* session not found */
		return nil, SessionNotFound
	}

	sd := &sessionData{
		status: Unmodified,
		token:  token,
	}
	/* sessionData 解码 */
	if sd.uid, sd.deadline, sd.values, err = s.codec.Decode(b); err != nil {
		return nil, err
	}

	/* 如果设置了 IdleTimeout，则将会话数据标记为已修改。这将强制会话数据以新的到期时间重新提交到会话存储 */
	if s.idleTimeout > 0 {
		sd.status = Modified
	}
	/* 绑定 sessionData */
	return s.addSessionDataToContext(ctx, sd), nil
}

// Commit 将会话数据保存到会话存储并返回会话令牌和到期时间
func (s *SessionManager) Commit(ctx context.Context) (string, time.Time, error) {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return "", time.Time{}, err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	/* token 初始化 */
	if goutil.IsEmpty(sd.token) {
		var err error
		if sd.token, err = generateToken(); err != nil {
			return "", time.Time{}, err
		}
	}
	/* 编码 */
	b, err := s.codec.Encode(sd.uid, sd.deadline, sd.values)
	if err != nil {
		return "", time.Time{}, err
	}
	/* 设置 session 会话周期 */
	expiry := sd.deadline
	if s.idleTimeout > 0 {
		ie := time.Now().Add(s.idleTimeout).UTC()
		if ie.Before(expiry) {
			expiry = ie
		}
	}
	/* store 提交 */
	if err := s.doStoreCommit(ctx, sd.token, b, expiry); err != nil {
		return "", time.Time{}, err
	}
	return sd.token, expiry, nil
}

// Destroy 从 ctx 会话存储中删除会话数据并将会话状态设置为已销毁
func (s *SessionManager) Destroy(ctx context.Context) error {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	err = s.doStoreDelete(ctx, sd.uid, sd.token)
	if err != nil {
		return err
	}

	sd.status = Destroyed

	// Reset everything else to defaults.
	sd.token = ""
	sd.deadline = time.Now().Add(s.lifetime).UTC()
	for key := range sd.values {
		delete(sd.values, key)
	}

	return nil
}

// Delete 删除指定 tokens
func (s *SessionManager) Delete(ctx context.Context, uid string, tokens []string) error {
	return s.doStoreDelete(ctx, uid, tokens...)
}

// Clear 根据 uid 列表删除数据
func (s *SessionManager) Clear(ctx context.Context, uids ...string) error {
	if uids == nil {
		return nil
	}
	return s.doStoreClear(ctx, uids)
}

// IsAlive 判断 token 是否存活
func (s *SessionManager) IsAlive(ctx context.Context, token string) (bool, error) {
	_, found, err := s.doStoreGet(ctx, token)
	if err != nil {
		return false, err
	} else if !found {
		return false, nil
	}
	return true, nil
}

// MergeSession 用于合并来自不同会话的数据，以防严格的会话令牌在 oauth 或类似的重定向流中丢失。
// 如果不使用新会话的值，请使用 ClearData()
func (s *SessionManager) MergeSession(ctx context.Context, token string) error {
	sd, err := s.getSessionDataFromContext(ctx)
	if err != nil {
		return err
	}

	b, found, err := s.doStoreGet(ctx, token)
	if err != nil {
		return err
	} else if !found {
		return nil
	}

	uid, deadline, values, err := s.codec.Decode(b)
	if err != nil {
		return err
	}

	sd.mu.Lock()
	defer sd.mu.Unlock()

	/* 如果是同一会话，则无需执行任何操作 */
	if goutil.IsEqual(sd.token, token) {
		return nil
	}

	sd.uid = uid
	if deadline.After(sd.deadline) {
		sd.deadline = deadline
	}

	for k, v := range values {
		sd.values[k] = v
	}

	sd.status = Modified
	return s.doStoreDelete(ctx, uid, token)
}

// 未鉴权回调
func (s *SessionManager) onUnauthorizedCallback(w http.ResponseWriter, r *http.Request, next http.Handler) {
	if s.unauthorizedCallback == nil {
		/* 返回 401 */
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	defer func() {
		if err := recover(); err != nil {
			logx.Error(err)
		}
	}()
	s.unauthorizedCallback(w, r, next)
}

func WithIdleTimeout(idleTimeout time.Duration) SessionOption {
	return func(s *SessionManager) {
		s.idleTimeout = idleTimeout
	}
}

func WithCodec(codec codec.Codec) SessionOption {
	return func(s *SessionManager) {
		s.codec = codec
	}
}

func WithDomain(domain string) SessionOption {
	return func(s *SessionManager) {
		s.cookie.domain = domain
	}
}

func WithHttpOnly(httpOnly bool) SessionOption {
	return func(s *SessionManager) {
		s.cookie.httpOnly = httpOnly
	}
}

func WithPath(path string) SessionOption {
	return func(s *SessionManager) {
		s.cookie.path = path
	}
}

func WithPersist(persist bool) SessionOption {
	return func(s *SessionManager) {
		s.cookie.persist = persist
	}
}

func WithSameSite(sameSite http.SameSite) SessionOption {
	return func(s *SessionManager) {
		s.cookie.sameSite = sameSite
	}
}

func WithSecure(secure bool) SessionOption {
	return func(s *SessionManager) {
		s.cookie.secure = secure
	}
}

// WithContext 自定义 ctx 会在 session 成功加载后调用
func WithContext(handle ContextHandle) SessionOption {
	return func(s *SessionManager) {
		s.contextHandle = handle
	}
}

// WithWhiteList 校验白名单
func WithWhiteList(url ...string) SessionOption {
	return func(s *SessionManager) {
		if url == nil {
			return
		}
		s.whiteList = arrutil.Unique(append(s.whiteList, url...))
	}
}

// WithErrorCallback 异常回调
func WithErrorCallback(callback ErrorCallback) SessionOption {
	return func(s *SessionManager) {
		s.errorCallback = callback
	}
}

// WithUnauthorizedCallback 未鉴权回调
func WithUnauthorizedCallback(callback UnauthorizedCallback) SessionOption {
	return func(s *SessionManager) {
		s.unauthorizedCallback = callback
	}
}

// 默认异常回调
func defaultErrorCallback(w http.ResponseWriter, r *http.Request, err error) {
	logx.Error(err)
	http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

// sessionResponseWriter
type sessionResponseWriter struct {
	http.ResponseWriter
	request        *http.Request
	sessionManager *SessionManager
	written        bool
}

func (sw *sessionResponseWriter) Write(b []byte) (int, error) {
	if !sw.written {
		sw.sessionManager.commitAndWriteSessionCookie(sw.ResponseWriter, sw.request)
		sw.written = true
	}
	return sw.ResponseWriter.Write(b)
}

func (sw *sessionResponseWriter) WriteHeader(code int) {
	if !sw.written {
		sw.sessionManager.commitAndWriteSessionCookie(sw.ResponseWriter, sw.request)
		sw.written = true
	}
	sw.ResponseWriter.WriteHeader(code)
}

func (sw *sessionResponseWriter) Unwrap() http.ResponseWriter {
	return sw.ResponseWriter
}
